{% comment %}
// vim: set syntax=asm :

/* mmm 32 x 1

    ymm0
    ymm1
    ymm2
    ymm3

System V ABI:
    args: rdi, rsi, rdx, rcx, r8, r9
    preserve: rbx, rsp, rbp, r12, r13, r14, r15
    scratch: rax, rdi, rsi, rdx, rcx, r8, r9, r10, r11
    return: rax (+rdx)

Windows ABI:
    args: RCX, RDX, R8, R9
    preserve: RBX, RBP, RDI, RSI, RSP, R12, R13, R14, R15, and XMM6-15
    scratch: RAX, RCX, RDX, R8, R9, R10, R11, XMM0-5, and the upper portions of YMM0-15 and ZMM0-15
    return: rax (+rdx)
*/
{% endcomment %}

{% include "preamble.tmpliq" type:"f32", size:"32x1", suffix:suffix, G:G %}

// Zero the tile. vzeroall clears every ymm register, which covers the
// ymm0-ymm3 accumulators used by this kernel, then control goes back to
// non_linear_loop (defined outside this section).
{{L}}clear:
    vzeroall
    jmp     {{L}}non_linear_loop

// Accumulating matrix product: entry and dispatch.
// Reads four qwords from the record pointed to by rdi:
//   +8  -> rbx  k (number of steps)
//   +16 -> rax  A pointer
//   +24 -> rcx  B pointer
//   +32 -> r8   packing selector
// A zero k leaves the accumulators untouched and goes straight back to
// non_linear_loop. Selector values 1 to 5 branch to the specialised loops
// below; any other value falls through into main_loop_packed_packed.
{{L}}add_mat_mul:
    mov     rcx,    [rdi + 24]   // B
    mov     rax,    [rdi + 16]   // A

    mov     rbx,    [rdi + 8]    // k
    mov     r8,    [rdi + 32]   // packing
    test    rbx,    rbx
    jz      {{L}}non_linear_loop

    // 1: A as 4-bit values with f16 scales, B as f32
    cmp     r8, 1
    jz      {{L}}q40f32

    // 2: A as 4-bit values with f16 scales, B as f16
    cmp     r8, 2
    jz      {{L}}q40f16

    // 3: A as f16, B as f16
    cmp     r8, 3
    jz      {{L}}f16f16

    // 4: A as f16, B as f32
    cmp     r8, 4
    jz      {{L}}f16f32

    // 5: A as f32, B as f16
    cmp     r8, 5
    jz      {{L}}f32f16

// Default loop (A and B both f32), one k step per iteration.
//   rax: A cursor, 128 bytes (32 f32) per step; vmovaps with a ymm operand
//        faults unless the address is 32-byte aligned
//   rcx: B cursor, one f32 per step, broadcast to all lanes of ymm15
//   rbx: remaining k, non-zero on entry (checked in add_mat_mul)
//   ymm0-ymm3 += ymm15 * A[0..31]
{{align}} 16
{{L}}main_loop_packed_packed:
    vbroadcastss    ymm15,  dword ptr [rcx]

    vmovaps     ymm8, [rax]
    vmovaps     ymm9, [rax + 32]
    vmovaps     ymm10, [rax + 64]
    vmovaps     ymm11, [rax + 96]

    vfmadd231ps     ymm0, ymm15, ymm8
    vfmadd231ps     ymm1, ymm15, ymm9
    vfmadd231ps     ymm2, ymm15, ymm10
    vfmadd231ps     ymm3, ymm15, ymm11

    add             rcx, 4
	add             rax, 128
    sub             rbx, 1
    jnz             {{L}}main_loop_packed_packed

    jmp             {{L}}non_linear_loop

// Constants loaded as dwords and broadcast by the q40f32 and q40f16 loops:
//   q40f32_mask:  0x0F in every byte, keeps the low nibble of each byte
//   q40f32_eight: the integer 8, subtracted from every unpacked nibble
// The two template branches hold the same values and differ only in the
// literal syntax (trailing h versus leading 0x).
{% if msvc %}
{{L}}q40f32_mask:
    {{long}} 0F0F0F0Fh
{{L}}q40f32_eight:
    {{long}} 08h
{% else %}
{{L}}q40f32_mask:
    {{long}} 0x0F0F0F0F
{{L}}q40f32_eight:
    {{long}} 8
{% endif %}

// A holds 4-bit values with f16 scales, B holds f32.
// A is consumed in groups of 32 k steps: 64 bytes of scales (32 f16, one per
// row) followed by 32 chunks of 16 bytes, each chunk carrying 32 nibbles
// (one per row) for a single k step.
//   rax: A cursor     rcx: B cursor (one f32 per k step)
//   rbx: remaining k  rdx: k steps left in the current group
// rbx drops by 32 per group and the outer loop only exits when it reaches
// exactly zero, so k is expected to be a multiple of 32.
{{L}}q40f32:
    // ymm0-3: acc
    // ymm4-7: scales
    // ymm13: 8
    // ymm14: mask
    // ymm15: b value
    vbroadcastss    ymm14, dword ptr [{{offset}} {{L}}q40f32_mask]
    vbroadcastss    ymm13, dword ptr [{{offset}} {{L}}q40f32_eight]

{{L}}q40f32_outerloop:
    // scales: 4 x 8 f16 loaded (16-byte aligned) and widened to f32
    vmovaps         xmm4, [rax]
    vmovaps         xmm5, [rax + 16]
    vmovaps         xmm6, [rax + 32]
    vmovaps         xmm7, [rax + 48]
    vcvtph2ps       ymm4, xmm4
    vcvtph2ps       ymm5, xmm5
    vcvtph2ps       ymm6, xmm6
    vcvtph2ps       ymm7, xmm7
    add             rax, 64

    // 32 k steps share the scales loaded above
    mov             rdx, 32

{{L}}q40f32_innerloop:
    vbroadcastss    ymm15, dword ptr [rcx]
    vmovaps         xmm8, [rax]            // 32 nibbles

    // low nibbles of the 16 bytes: rows 0-7 go to ymm9, rows 8-15 to ymm10
    vpand           xmm10, xmm8, xmm14      // 16 bytes

    vpmovzxbd       ymm9, xmm10            // 8 u32

    vpermilpd       xmm10, xmm10, 1        // swap 64bit halves
    vpmovzxbd       ymm10, xmm10            // 8 u32

    // high nibbles: shift each word right by 4, the mask then drops the bits
    // that crossed over from the neighbouring byte.
    // rows 16-23 go to ymm11, rows 24-31 to ymm12
    vpsrlw          xmm8, xmm8, 4
    vpand           xmm12, xmm8, xmm14      // 16 bytes
    vpmovzxbd       ymm11, xmm12            // 8 u32
    vpermilpd       xmm12, xmm12, 1        // swap 64bit halves
    vpmovzxbd       ymm12, xmm12            // 8 u32

    // recentre the unsigned nibbles from 0..15 to -8..7
    vpsubd          ymm9, ymm9, ymm13
    vpsubd          ymm10, ymm10, ymm13
    vpsubd          ymm11, ymm11, ymm13
    vpsubd          ymm12, ymm12, ymm13

    // signed i32 to f32
    vcvtdq2ps       ymm9, ymm9
    vcvtdq2ps       ymm10, ymm10
    vcvtdq2ps       ymm11, ymm11
    vcvtdq2ps       ymm12, ymm12

    // apply the per-row scales
    vmulps          ymm9, ymm9, ymm4
    vmulps          ymm10, ymm10, ymm5
    vmulps          ymm11, ymm11, ymm6
    vmulps          ymm12, ymm12, ymm7

    // acc += b * dequantised a
    vfmadd231ps     ymm0, ymm15, ymm9
    vfmadd231ps     ymm1, ymm15, ymm10
    vfmadd231ps     ymm2, ymm15, ymm11
    vfmadd231ps     ymm3, ymm15, ymm12

    add             rax, 16
    add             rcx, 4
    sub             rdx, 1
    jnz             {{L}}q40f32_innerloop

    sub             rbx, 32
    jnz             {{L}}q40f32_outerloop

    jmp             {{L}}non_linear_loop

// A holds 4-bit values with f16 scales, B holds f16.
// Same structure as the q40f32 loop (and it reuses the q40f32_mask and
// q40f32_eight constants); the differences are that each B value is a
// 16-bit half widened to f32 before use and that rcx advances by 2 bytes.
//   rax: A cursor     rcx: B cursor (one f16 per k step)
//   rbx: remaining k  rdx: k steps left in the current group of 32
// rbx drops by 32 per group and the outer loop only exits when it reaches
// exactly zero, so k is expected to be a multiple of 32.
{{L}}q40f16:
    // ymm0-3: acc
    // ymm4-7: scales
    // ymm13: 8
    // ymm14: mask
    // ymm15: b value
    vbroadcastss    ymm14, dword ptr [{{offset}} {{L}}q40f32_mask]
    vbroadcastss    ymm13, dword ptr [{{offset}} {{L}}q40f32_eight]

{{L}}q40f16_outerloop:
    // scales: 4 x 8 f16 loaded (16-byte aligned) and widened to f32
    vmovaps         xmm4, [rax]
    vmovaps         xmm5, [rax + 16]
    vmovaps         xmm6, [rax + 32]
    vmovaps         xmm7, [rax + 48]
    vcvtph2ps       ymm4, xmm4
    vcvtph2ps       ymm5, xmm5
    vcvtph2ps       ymm6, xmm6
    vcvtph2ps       ymm7, xmm7
    add             rax, 64

    // 32 k steps share the scales loaded above
    mov             rdx, 32

{{L}}q40f16_innerloop:
    // replicate the f16 B value into every word, then widen the low eight
    // copies so that all eight f32 lanes of ymm15 hold it
    vpbroadcastw    ymm15, word ptr [rcx]
    vcvtph2ps       ymm15, xmm15

    vmovaps         xmm8, [rax]            // 32 nibbles

    // low nibbles of the 16 bytes: rows 0-7 go to ymm9, rows 8-15 to ymm10
    vpand           xmm10, xmm8, xmm14      // 16 bytes

    vpmovzxbd       ymm9, xmm10            // 8 u32

    vpermilpd       xmm10, xmm10, 1        // swap 64bit halves
    vpmovzxbd       ymm10, xmm10            // 8 u32

    // high nibbles: shift each word right by 4, the mask then drops the bits
    // that crossed over from the neighbouring byte.
    // rows 16-23 go to ymm11, rows 24-31 to ymm12
    vpsrlw          xmm8, xmm8, 4
    vpand           xmm12, xmm8, xmm14      // 16 bytes
    vpmovzxbd       ymm11, xmm12            // 8 u32
    vpermilpd       xmm12, xmm12, 1        // swap 64bit halves
    vpmovzxbd       ymm12, xmm12            // 8 u32

    // recentre the unsigned nibbles from 0..15 to -8..7
    vpsubd          ymm9, ymm9, ymm13
    vpsubd          ymm10, ymm10, ymm13
    vpsubd          ymm11, ymm11, ymm13
    vpsubd          ymm12, ymm12, ymm13

    // signed i32 to f32
    vcvtdq2ps       ymm9, ymm9
    vcvtdq2ps       ymm10, ymm10
    vcvtdq2ps       ymm11, ymm11
    vcvtdq2ps       ymm12, ymm12

    // apply the per-row scales
    vmulps          ymm9, ymm9, ymm4
    vmulps          ymm10, ymm10, ymm5
    vmulps          ymm11, ymm11, ymm6
    vmulps          ymm12, ymm12, ymm7

    // acc += b * dequantised a
    vfmadd231ps     ymm0, ymm15, ymm9
    vfmadd231ps     ymm1, ymm15, ymm10
    vfmadd231ps     ymm2, ymm15, ymm11
    vfmadd231ps     ymm3, ymm15, ymm12

    add             rax, 16
    add             rcx, 2
    sub             rdx, 1
    jnz             {{L}}q40f16_innerloop

    sub             rbx, 32
    jnz             {{L}}q40f16_outerloop

    jmp             {{L}}non_linear_loop

// A as f16, B as f16: one k step per iteration.
//   rax: A cursor, 64 bytes (32 f16) per step, widened to f32 in ymm4-ymm7;
//        vmovaps with an xmm operand needs a 16-byte aligned address
//   rcx: B cursor, one f16 per step, replicated and widened into ymm15
//   rbx: remaining k, non-zero on entry (checked in add_mat_mul)
//   ymm0-ymm3 += ymm15 * A[0..31]
// The alignment directive is placed before the label, as for
// main_loop_packed_packed, so that the loop head itself is aligned and the
// back edge does not execute the padding on every iteration.
{{align}} 16
{{L}}f16f16:
    vpbroadcastw    ymm15, word ptr [rcx]

    vmovaps     xmm4, [rax]
    vmovaps     xmm5, [rax + 16]
    vmovaps     xmm6, [rax + 32]
    vmovaps     xmm7, [rax + 48]

    vcvtph2ps       ymm15, xmm15
    vcvtph2ps       ymm4, xmm4
    vcvtph2ps       ymm5, xmm5
    vcvtph2ps       ymm6, xmm6
    vcvtph2ps       ymm7, xmm7

    vfmadd231ps     ymm0, ymm15, ymm4
    vfmadd231ps     ymm1, ymm15, ymm5
    vfmadd231ps     ymm2, ymm15, ymm6
    vfmadd231ps     ymm3, ymm15, ymm7

    add             rcx, 2
    add             rax, 64
    sub             rbx, 1
    jnz             {{L}}f16f16

    jmp             {{L}}non_linear_loop

// A as f32, B as f16: one k step per iteration.
//   rax: A cursor, 128 bytes (32 f32) per step; vmovaps with a ymm operand
//        needs a 32-byte aligned address
//   rcx: B cursor, one f16 per step, replicated and widened into ymm15
//   rbx: remaining k, non-zero on entry (checked in add_mat_mul)
//   ymm0-ymm3 += ymm15 * A[0..31]
// The alignment directive is placed before the label, as for
// main_loop_packed_packed, so that the loop head itself is aligned and the
// back edge does not execute the padding on every iteration.
{{align}} 16
{{L}}f32f16:
    vpbroadcastw    ymm15, word ptr [rcx]

    vmovaps     ymm4, [rax]
    vmovaps     ymm5, [rax + 32]
    vmovaps     ymm6, [rax + 64]
    vmovaps     ymm7, [rax + 96]

    vcvtph2ps       ymm15, xmm15

    vfmadd231ps     ymm0, ymm15, ymm4
    vfmadd231ps     ymm1, ymm15, ymm5
    vfmadd231ps     ymm2, ymm15, ymm6
    vfmadd231ps     ymm3, ymm15, ymm7

    add             rcx, 2
    add             rax, 128
    sub             rbx, 1
    jnz             {{L}}f32f16

    jmp             {{L}}non_linear_loop

// A as f16, B as f32: one k step per iteration.
//   rax: A cursor, 64 bytes (32 f16) per step, widened to f32 in ymm4-ymm7;
//        vmovaps with an xmm operand needs a 16-byte aligned address
//   rcx: B cursor, one f32 per step, broadcast to all lanes of ymm15
//   rbx: remaining k, non-zero on entry (checked in add_mat_mul)
//   ymm0-ymm3 += ymm15 * A[0..31]
// The alignment directive is placed before the label, as for
// main_loop_packed_packed, so that the loop head itself is aligned and the
// back edge does not execute the padding on every iteration.
{{align}} 16
{{L}}f16f32:
    vbroadcastss    ymm15,  dword ptr [rcx]

    vmovaps     xmm4, [rax]
    vmovaps     xmm5, [rax + 16]
    vmovaps     xmm6, [rax + 32]
    vmovaps     xmm7, [rax + 48]

    vcvtph2ps       ymm4, xmm4
    vcvtph2ps       ymm5, xmm5
    vcvtph2ps       ymm6, xmm6
    vcvtph2ps       ymm7, xmm7

    vfmadd231ps     ymm0, ymm15, ymm4
    vfmadd231ps     ymm1, ymm15, ymm5
    vfmadd231ps     ymm2, ymm15, ymm6
    vfmadd231ps     ymm3, ymm15, ymm7

    add             rcx, 4
    add             rax, 64
    sub             rbx, 1
    jnz             {{L}}f16f32

    jmp             {{L}}non_linear_loop


{% include "fma_mmm_f32_scalars.tmpliq" from:0, to:3, type:"f32" %}
{% include "fma_mmm_f32_per_rows.tmpliq" mr:32, from:0, to:3, type:"f32" %}
{% include "fma_mmm_f32_per_cols.tmpliq" mr:32, from:0, to:3, type:"f32" %}
{% include "fma_mmm_load_tile.tmpliq" from:0, to:3 %}

// Add a 32-element f32 column read from memory to the accumulators.
// Reads from the record pointed to by rdi:
//   +8  -> r10  source pointer
//   +16 -> rsi  row stride in bytes
// A stride of 4 means the 32 values are contiguous, so they are added
// straight from memory (128 bytes, no alignment required by vaddps).
// Any other stride is handled by add_unicast_generic, which relies on
// r10 and rsi as set here.
{{L}}add_unicast:
    mov     r10,    [rdi + 8]           // c ptr
    mov     rsi,    [rdi + 16]          // row stride

    cmp     rsi, 4
    jne     {{L}}add_unicast_generic

    {% for row in (0..3) %}
        vaddps ymm{{row}}, ymm{{row}}, [ r10 + {{row|times:32}} ]
    {% endfor %}
    jmp    {{L}}non_linear_loop

// Strided variant of add_unicast. Expects r10 = source pointer and
// rsi = row stride in bytes, as set by add_unicast.
// Builds the eight byte offsets 0, s, 2s, ... 7s (s = esi, 32-bit
// arithmetic) in ymm14, then gathers and adds eight rows at a time.
// VEX-encoded vpinsrd is used for the inserts so that no legacy-SSE
// instruction runs while the upper ymm halves hold live data; the bits it
// zeroes above bit 127 are not read, since vperm2f128 only takes the low
// 128-bit lanes of its two sources.
{{L}}add_unicast_generic:
    xor     eax,    eax
{% for i in (0..3) %}
    vpinsrd xmm14, xmm14, eax, {{i}}
    add     eax,    esi
{% endfor %}
{% for i in (0..3) %}
    vpinsrd xmm15, xmm15, eax, {{i}}
    add     eax,    esi
{% endfor %}

    vperm2f128      ymm14,  ymm14, ymm15,         32 // ymm14 <- xmm14::xmm15

{% for i in (0..3) %}
    // vgatherdps clears its mask register, so the all-ones mask is rebuilt
    // before every gather
    vpcmpeqd        ymm15,  ymm15, ymm15
    vgatherdps      ymm12,  [ r10 + ymm14 ], ymm15

    vaddps          ymm{{i}},   ymm{{i}},   ymm12
    lea             r10, [ r10 + rsi * 8 ]
{% endfor %}

    jmp    {{L}}non_linear_loop

// Accumulate the product of a 32-element f32 vector and a single f32 value.
// Reads from the record pointed to by rdi:
//   +8  -> rax  pointer to 32 f32 (read with unaligned loads)
//   +16 -> rbx  pointer to one f32, broadcast to all lanes of ymm14
// For each accumulator i: ymm_i += rows[8*i .. 8*i+7] * ymm14.
{{L}}add_row_col_products:
    mov             rax, [ rdi + 8 ]
    mov             rbx, [ rdi + 16 ]

    vbroadcastss    ymm14, dword ptr [rbx]

{% for i in (0..3) %}
    vmovups         ymm12,  [rax + {{i|times:32}}]
    vfmadd231ps     ymm{{i}}, ymm12, ymm14
{% endfor %}
    jmp    {{L}}non_linear_loop

// Write the accumulators to memory.
// Reads from the record pointed to by rdi:
//   +8  -> r8   destination pointer
//   +16 -> rsi  row stride in bytes
//   +32 -> r11  item size in bytes
// An item size of 2 selects the f16 path (store_f16). Otherwise a stride
// of 4 means contiguous f32 output, written here with four unaligned
// 32-byte stores; any other stride goes to store_generic. Both of those
// targets rely on r8 and rsi as set here.
{{L}}store:
    mov     r8,     [rdi + 8]           // c ptr
    mov     rsi,    [rdi + 16]          // row stride
    mov     r11,    [rdi + 32]          // item size

    cmp     r11, 2
    je      {{L}}store_f16

	cmp rsi, 4
	jne {{L}}store_generic

	{% for row in (0..3) %}
        vmovups [r8 + {{row|times:32}}], ymm{{row}}
    {% endfor %}

    jmp     {{L}}non_linear_loop

// Strided f32 store. Expects r8 = destination pointer and rsi = row stride
// in bytes, as set by store. Writes the 32 accumulator values one dword at
// a time, advancing r8 by the stride after each.
// Each accumulator is handled as two 128-bit halves staged in xmm9: the low
// half is copied directly, the high half is brought down with vperm2f128.
// The copy uses VEX-encoded vmovaps so that no legacy-SSE instruction runs
// while the upper ymm halves hold live data; only xmm9 is read afterwards,
// so zeroing the upper half of ymm9 has no visible effect.
{{L}}store_generic:

    {% for vec in (0..3) %}
        {% for half in (0..1) %}
            {% if half == 0 %}
                vmovaps xmm9, xmm{{vec}}
            {% else %}
                vperm2f128 ymm9, ymm{{vec}}, ymm{{vec}}, 1
            {% endif %}
            {% for row in (0..3) %}
                vextractps  dword ptr [r8], xmm9, {{row}}
                add         r8, rsi
            {% endfor %}
        {% endfor %}
    {% endfor %}

    jmp    {{L}}non_linear_loop

// f16 store path, reached from store when the item size is 2. Expects
// r8 = destination pointer and rsi = row stride in bytes, as set by store.
// The four f32 accumulators are narrowed in place to 8 f16 each (immediate 0
// selects round-to-nearest-even), so ymm0-ymm3 no longer hold the f32 tile
// afterwards. A stride of 2 means contiguous output, written with four
// unaligned 16-byte stores; any other stride goes to store_generic_f16.
{{L}}store_f16:

    vcvtps2ph   xmm0, ymm0, 0
    vcvtps2ph   xmm1, ymm1, 0
    vcvtps2ph   xmm2, ymm2, 0
    vcvtps2ph   xmm3, ymm3, 0

    cmp         rsi, 2
	jne {{L}}store_generic_f16

	{% for row in (0..3) %}
        vmovups [r8 + {{row|times:16}}], xmm{{row}}
    {% endfor %}

    jmp     {{L}}non_linear_loop
    
// Strided f16 store. Expects xmm0-xmm3 to hold 8 f16 values each and
// r8 / rsi to hold the destination pointer and row stride in bytes, as
// prepared by store and store_f16. Writes the 32 halves one word at a time,
// advancing r8 by the stride after each.
// VEX-encoded vpextrw is used so that no legacy-SSE instruction runs while
// other ymm registers may still carry data in their upper halves; the word
// written is the same as with pextrw.
{{L}}store_generic_f16:

    {% for vec in (0..3) %}
        {% for row in (0..7) %}
            vpextrw     word ptr [r8], xmm{{vec}}, {{row}}
            add         r8, rsi
        {% endfor %}
    {% endfor %}

    jmp     {{L}}non_linear_loop

{% include "postamble.tmpliq" type:"f32", size:"32x1", suffix:suffix, G:G, L:L %}
